{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FPjTnCG-v2Z6",
        "outputId": "553badc6-ed9d-469c-ec74-f8606018ae23"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "zsh:1: command not found: nvidia-smi\n"
          ]
        }
      ],
      "source": [
        "\"\"\"\n",
        "Notebook for training the embedding model for the Rossler system.\n",
        "Since this is not a built in example, we will need to implement our our config,\n",
        "model and data handler.\n",
        "=====\n",
        "Distributed by: Notre Dame SCAI Lab (MIT Liscense)\n",
        "- Associated publication:\n",
        "url: https://arxiv.org/abs/2010.03957\n",
        "doi: \n",
        "github: https://github.com/zabaras/transformer-physx\n",
        "=====\n",
        "\"\"\"\n",
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "34XMtg9FZFql"
      },
      "source": [
        "# Environment Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[1m\u001b[36mdata\u001b[m\u001b[m/                            train_rossler_enn.py\n",
            "\u001b[1m\u001b[36moutputs\u001b[m\u001b[m/                         train_rossler_transformer.ipynb\n",
            "train_rossler_enn.ipynb          train_rossler_transformer.py\n"
          ]
        }
      ],
      "source": [
        "%ls"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6MK_h8wF0Rr4"
      },
      "source": [
        "Now lets download the training and validation data for the lorenz system. Info on wget from [Google drive](https://stackoverflow.com/questions/37453841/download-a-file-from-google-drive-using-wget). This will eventually be update to zenodo repo."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2NtZ02zD0EKo",
        "outputId": "3d9cdf22-4d9a-4fa8-bf3d-a5b16091d30b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "mkdir: data: File exists\n"
          ]
        }
      ],
      "source": [
        "!mkdir data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dCHiYnrdZN95"
      },
      "source": [
        "# Transformer-PhysX Rossler System\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SoNLzt1xuQk4",
        "outputId": "40e039ef-2c2e-4905-ce41-e5d28f62b10e"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "import os\n",
        "import logging\n",
        "import h5py\n",
        "import paddle\n",
        "import paddle.nn as nn\n",
        "import numpy as np\n",
        "\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "# Torch imports\n",
        "from paddle.io import DataLoader, Dataset\n",
        "from paddle.optimizer.lr import ExponentialDecay\n",
        "\n",
        "# Trphysx imports\n",
        "from trphysx.config.configuration_auto import AutoPhysConfig\n",
        "from trphysx.embedding.training import (\n",
        "    EmbeddingParser,EmbeddingTrainer\n",
        ")\n",
        "from trphysx.embedding.embedding_auto import AutoEmbeddingModel\n",
        "from trphysx.embedding.training import AutoDataHandler\n",
        "\n",
        "logger = logging.getLogger(__name__)\n",
        "\n",
        "argv = []\n",
        "argv = argv + [\"--exp_name\", \"rossler\"]\n",
        "argv = argv + [\"--training_h5_file\", \"./data/rossler_training.hdf5\"]\n",
        "argv = argv + [\"--eval_h5_file\", \"./data/rossler_valid.hdf5\"]\n",
        "argv = argv + [\"--stride\", \"16\"]\n",
        "argv = argv + [\"--batch_size\", \"256\"]\n",
        "argv = argv + [\"--block_size\", \"16\"]\n",
        "argv = argv + [\"--n_train\", \"1024\"]\n",
        "argv = argv + [\"--n_eval\", \"32\"]\n",
        "argv = argv + [\"--epochs\", \"100\"]\n",
        "\n",
        "# Setup logging\n",
        "logging.basicConfig(\n",
        "    format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
        "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
        "    level=logging.INFO,\n",
        ")\n",
        "\n",
        "args = EmbeddingParser().parse(argv)\n",
        "args.device = paddle.set_device(\"gpu:0\" if paddle.is_compiled_with_cuda() else \"cpu\")\n",
        "logger.info(\"Torch device: {}\".format(args.device))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T9Fel9vqZVsH"
      },
      "source": [
        "## Initializing Datasets and Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JjxiqD3lF98_"
      },
      "source": [
        "Now we can use the auto classes to initialized the predefined configs, dataloaders and models. This may take a bit!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EcpC9Fy243RN",
        "outputId": "d8d5a891-d6a1-4d3d-afa1-13fbcb65d9c5"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/Users/puqing/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/dygraph/math_op_patch.py:275: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float64, but right dtype is paddle.float32, the right dtype will convert to paddle.float64\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "config = AutoPhysConfig.load_config(args.exp_name)\n",
        "\n",
        "dataloader = AutoDataHandler.load_data_handler(args.exp_name)\n",
        "\n",
        "# Set up data-loaders\n",
        "training_loader = dataloader.createTrainingLoader(\n",
        "    args.training_h5_file,\n",
        "    block_size=args.block_size,\n",
        "    stride=args.stride,\n",
        "    ndata=args.n_train,\n",
        "    batch_size=args.batch_size,\n",
        ")\n",
        "\n",
        "testing_loader = dataloader.createTestingLoader(\n",
        "    args.eval_h5_file, block_size=32, ndata=args.n_eval, batch_size=8\n",
        ")\n",
        "\n",
        "# Load configuration file then init model\n",
        "model = AutoEmbeddingModel.init_trainer(args.exp_name, config).to(args.device)\n",
        "\n",
        "if args.epoch_start > 1:\n",
        "    model.load_model(args.ckpt_dir, args.epoch_start)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "auDiMVZ5UfNz"
      },
      "source": [
        "Initialize optimizer and scheduler. Feel free to change if you want to experiment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "bnrtuKdhGuWQ"
      },
      "outputs": [],
      "source": [
        "scheduler = ExponentialDecay(\n",
        "    learning_rate=args.lr * 0.995 ** (args.epoch_start), gamma=0.995\n",
        ")\n",
        "\n",
        "optimizer = paddle.optimizer.Adam(\n",
        "    parameters=model.parameters(),\n",
        "    learning_rate=scheduler,\n",
        "    weight_decay=1e-8,\n",
        "    grad_clip=paddle.nn.ClipGradByGlobalNorm(0.1),\n",
        ")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "It_LLGQIZe0b"
      },
      "source": [
        "## Training the Embedding Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2XPKfYTUuXf"
      },
      "source": [
        "Train the model. No visualization here, just boring numbers. This notebook only trains for 100 epochs for brevity, feel free to train longer. The test loss here is only the recovery loss MSE(x - decode(encode(x))) and does not reflect the quality of the Koopman dynamics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2ic9DFQcUWpm",
        "outputId": "b8949287-465f-49c7-9c08-5a4ea401c045"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/Users/puqing/opt/miniconda3/lib/python3.9/site-packages/paddle/fluid/dygraph/math_op_patch.py:275: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float64, but right dtype is paddle.float32, the right dtype will convert to paddle.float64\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "trainer = EmbeddingTrainer(model, args, (optimizer, scheduler))\n",
        "trainer.train(training_loader, testing_loader)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h3ZEE4ixh8nr"
      },
      "source": [
        "Check your Google drive for checkpoints."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "train_rossler_enn.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "base",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.12"
    },
    "vscode": {
      "interpreter": {
        "hash": "db64661fe0d122184d8f0ece1104b953994d7acbc475dabbdd4eb4bc24907e06"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
